"""
Created on Fri Nov 20 15:07:36 2020

@author: lopez
"""

# This block of lines imports the functions and libraries used by the script.
import glob
import numpy as np 
from skimage import io, img_as_float, filters
from skimage import img_as_float32
import tkinter as tk
from tkinter import filedialog
import os
import tifffile as tf

"""This script analyses every .tif image in a folder in order to keep only
the brightest components of each image. The brightest components are selected
with a Multi-Otsu thresholding algorithm."""

def _read_calibration(path):
    """Return ``(microns_per_pixel, time_between_frames)`` for the TIFF at *path*.

    Both values are taken from the tags of the file's first page. The pixel
    size uses the first element of the ``XResolution`` tag. The frame interval
    is the text after ``=`` on the fifth line of ``ImageDescription``.
    """
    with tf.TiffFile(path) as tif:
        # A later tag with a duplicate name overwrites an earlier one.
        tags = {tag.name: tag.value for tag in tif.pages[0].tags.values()}

    # NOTE(review): XResolution is a TIFF rational (numerator, denominator).
    # Only the numerator is used, and the 1e-6 factor assumes it is expressed
    # in pixels per metre. Confirm against the acquisition software.
    microns_per_pixel = 1 / (1e-6 * tags['XResolution'][0])

    # NOTE(review): this is a positional parse. It assumes line 5 of the
    # description is "key=<interval in seconds>" for every file. Confirm this.
    description_lines = tags['ImageDescription'].split('\n')
    time_between_frames = float(description_lines[4].split('=')[1])
    return microns_per_pixel, time_between_frames


def _keep_brightest(img):
    """Return a copy of *img* in which every pixel outside the brightest class is 0.

    ``threshold_multiotsu`` is called with its defaults and yields two
    thresholds. ``np.digitize`` then labels each pixel 0, 1 or 2, and only the
    top label is kept. The thresholds are computed once over the whole array,
    not per frame.
    """
    thresholds = filters.threshold_multiotsu(img)
    region = np.digitize(img, thresholds)
    masked = img.copy()  # copy() keeps img's dtype, as zeros_like did.
    masked[region < len(thresholds)] = 0
    return masked


def _save_for_imagej(out_path, stack, microns_per_pixel):
    """Write a (frames, rows, cols) *stack* as a float32 ImageJ TIFF with pixel calibration."""
    # tifffile's ImageJ mode expects TZCYXS axis order. Frames map to T, and
    # Z, C and S are singleton axes.
    hyperstack = img_as_float32(stack[:, np.newaxis, np.newaxis, :, :, np.newaxis])
    pixels_per_micron = 1 / microns_per_pixel
    tf.imwrite(out_path, hyperstack, imagej=True,
               resolution=(pixels_per_micron, pixels_per_micron),
               metadata={'spacing': 1.0, 'unit': 'micron'})


# Ask the user for the folder that contains the wildtype images. The hidden
# root window exists only to host the dialog, so it is destroyed straight away.
root = tk.Tk()
root.withdraw()
folder_selected = filedialog.askdirectory(title='Select the folder that contains the wildtype images.')
root.destroy()

# askdirectory returns an empty value when the dialog is cancelled. Without this
# check, glob would silently fall back to '*.tif' in the current directory.
if not folder_selected:
    raise SystemExit('No folder selected; nothing to process.')

# Process each .tif image in the selected folder.
for filename in glob.glob(os.path.join(folder_selected, '*.tif')):
    # File name without directory or extension; used to name the output file.
    file_name1 = os.path.basename(os.path.splitext(filename)[0])

    # Read the image and convert it to floating point.
    img = img_as_float(io.imread(filename))

    microns_per_pixel, time_between_frames = _read_calibration(filename)
    # NOTE(review): time_between_frames is parsed but not written to the output.
    # tifffile's ImageJ metadata accepts a 'finterval' entry if it should be kept.

    # Keep only the brightest Multi-Otsu class. The frame loop assumes img is
    # (frames, rows, cols); confirm for single-frame files.
    masked_img = _keep_brightest(img)

    # The output is written to the current working directory, not next to the input.
    _save_for_imagej(file_name1 + '_for_imageJ.tiff', masked_img, microns_per_pixel)
    
  
